import random

# Input list of image paths (one per line) and the three output list files.
input_file = './all_image.txt'
train_file = './train.txt'
test_file = './test.txt'
val_file = './val.txt'

# Split ratios.
# NOTE: test_ratio is not consulted by the split; the test set receives
# whatever remains after the train and validation slices are taken. The
# name is kept so existing references to it keep working.
train_ratio = 0.8
val_ratio = 0.2
test_ratio = 0.001

# Slack allowed when checking that the ratios do not exceed 1, so that sums
# such as 0.1 + 0.2 + 0.7 are not rejected because of float rounding.
_RATIO_TOLERANCE = 1e-9


def read_image_paths(path: str) -> list:
    """Return the stripped, non-blank lines of the text file at *path*.

    Blank and whitespace-only lines are skipped so that they do not end up
    as empty entries in the generated split files. The file is opened with
    the platform's default text encoding.
    """
    with open(path, 'r') as f:
        stripped_lines = (line.strip() for line in f)
        return [entry for entry in stripped_lines if entry]


def split_paths(paths: list, train_fraction: float, val_fraction: float) -> tuple:
    """Partition *paths* into ``(train, val, test)`` lists.

    The train and validation sizes are ``int(fraction * len(paths))``
    (truncated); the test list is everything left over, so no path is lost
    or duplicated. *paths* itself is not modified and is split in its
    current order.

    Raises:
        ValueError: if either fraction is negative or their sum exceeds 1.
    """
    if train_fraction < 0 or val_fraction < 0:
        raise ValueError("Ratios must be non-negative")
    if train_fraction + val_fraction > 1 + _RATIO_TOLERANCE:
        raise ValueError("train and val ratios must not sum to more than 1")

    total = len(paths)
    num_train = min(int(train_fraction * total), total)
    # Clamp so that float rounding can never push the slices past the end.
    num_val = min(int(val_fraction * total), total - num_train)

    train_end = num_train
    val_end = num_train + num_val
    return paths[:train_end], paths[train_end:val_end], paths[val_end:]


def write_paths(path: str, image_paths: list) -> None:
    """Write *image_paths* to the file at *path*, one entry per line."""
    with open(path, 'w') as f:
        f.writelines(entry + '\n' for entry in image_paths)


def main() -> None:
    """Shuffle the input image list and write the train/test/val files."""
    all_image_paths = read_image_paths(input_file)

    # Shuffle in place so each split is a random sample of the input.
    random.shuffle(all_image_paths)

    train_paths, val_paths, test_paths = split_paths(
        all_image_paths, train_ratio, val_ratio
    )

    print("train number : {}".format(len(train_paths)))
    print("val number : {}".format(len(val_paths)))
    print("test number : {}".format(len(test_paths)))

    write_paths(train_file, train_paths)
    write_paths(test_file, test_paths)
    write_paths(val_file, val_paths)

    print(f"Images have been written to {train_file}, {test_file}, and {val_file}")


if __name__ == "__main__":
    main()